Open In Colab

This notebook will contain the description and code for various type of Autoencoders.

Dataset¶

we are going to use an anime face dataset and our aim is to generate or reproduce anime faces

In [1]:
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
In [2]:
''' Link to explain how to download Datasets from kaggle https://www.kaggle.com/general/74235'''
!pip install -q kaggle
!mkdir ~/.kaggle
!cp '/content/drive/My Drive/Kaggle/kaggle.json' ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
mkdir: cannot create directory ‘/root/.kaggle’: File exists
In [3]:
%%time
!kaggle datasets download -d splcher/animefacedataset -p dataset
!unzip -q dataset/animefacedataset.zip -d dataset/animefacedataset
!rm dataset/animefacedataset.zip
Downloading animefacedataset.zip to dataset
 97% 383M/395M [00:02<00:00, 189MB/s]
100% 395M/395M [00:02<00:00, 145MB/s]
replace dataset/animefacedataset/images/0_2000.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: N
CPU times: user 86.5 ms, sys: 31.2 ms, total: 118 ms
Wall time: 10.8 s
In [4]:
import os

dataset_dir = "dataset/animefacedataset/images"
image_files = [os.path.join(dataset_dir, x) for x in os.listdir(dataset_dir)] 
len(image_files)
Out[4]:
63565
In [5]:
from matplotlib import pyplot as plt
import numpy as np
import math
import cv2

def plot_images(images):
  n_col = 8
  n_row = int(math.ceil(len(images) / n_col))
  _, axs = plt.subplots(n_row, n_col, figsize=(12, 12))
  axs = axs.flatten()
  for img, ax in zip(images, axs):
      if os.path.exists(img):
          img = cv2.imread(img)
          img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
      ax.imshow(img)
  plt.show()

from mpl_toolkits.axes_grid1 import ImageGrid
def plot_images(images, n_col=8):
  n_row = int(math.ceil(len(images) / n_col))
  fig = plt.figure(figsize=(12., 12.))
  grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(n_row, n_col),  # creates 2x2 grid of axes
                 axes_pad=0.0,  # pad between axes in inch.
                 )

  for ax, img in zip(grid, images):
      # Iterating over the grid returns the Axes.
      if type(img) == str and os.path.exists(img):
          img = cv2.imread(img)
          img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
          img = cv2.resize(img, (64, 64))  # Reshaping for visualization
      ax.imshow(img)
  plt.show()
plot_images(image_files[0:16])
In [6]:
from sklearn.model_selection import train_test_split
images_files_train, images_files_test = train_test_split(image_files, test_size=0.3, shuffle=True)

print("Train:", len(images_files_train))
print("Test:", len(images_files_test))
Train: 44495
Test: 19070
In [7]:
def read_image_file(imgfile):
  img = cv2.imread(imgfile)
  img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
  img = cv2.resize(img, (64, 64))  # Reshaping for visualization
  return img.astype(np.uint8)
In [8]:
images_train = np.array([read_image_file(x) for x in images_files_train])
images_train.shape
Out[8]:
(44495, 64, 64, 3)
In [9]:
images_test = np.array([read_image_file(x) for x in images_files_test])
images_test.shape
Out[9]:
(19070, 64, 64, 3)
In [10]:
images_shape = images_test[0].shape
total_pixels = np.size(images_test[0])
images_shape, total_pixels
Out[10]:
((64, 64, 3), 12288)

Autoencoder - DNN¶

These type of autoencoders contains dense layers as encoder and decoder

Lets try to build an autoencoder using only dense layers to reproduce same input image

Training¶

In [67]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint

model_file = 'model_ae_dnn.h5'

model = keras.Sequential(name="my_sequential")
model.add(keras.Input(shape=images_shape, dtype=tf.int8))
model.add(layers.Flatten())
model.add(layers.Dense(128, activation="relu", name="encoder_layer_1"))
model.add(layers.Dense(64, activation="relu", name="encoder_layer_2"))
model.add(layers.Dense(32, activation="relu", name="encoder_layer_3"))
model.add(layers.Dense(16, activation="relu", name="encoder_layer_4"))

model.add(layers.Dense(8, name="code"))

model.add(layers.Dense(16, activation="relu", name="decoder_layer_1"))
model.add(layers.Dense(32, activation="relu", name="decoder_layer_2"))
model.add(layers.Dense(64, activation="relu", name="decoder_layer_3"))
model.add(layers.Dense(128, activation="relu", name="decoder_layer_4"))

model.add(layers.Dense(total_pixels, activation="relu", name="final_layer"))
model.add(layers.Reshape(images_shape))

checkpoint = ModelCheckpoint(model_file, verbose=0, monitor='val_loss', save_best_only=True, mode='auto')
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
# tf.keras.losses.MeanAbsoluteError()
# tf.keras.losses.MeanSquaredError()
# tf.keras.losses.kullback_leibler_divergence()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.MeanSquaredError(), 
              metrics=['mse']
              )
model.summary()
WARNING:tensorflow:Please add `keras.layers.InputLayer` instead of `keras.Input` to Sequential model. `keras.Input` is intended to be used by Functional model.
Model: "my_sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_9 (Flatten)          (None, 12288)             0         
_________________________________________________________________
encoder_layer_1 (Dense)      (None, 128)               1572992   
_________________________________________________________________
encoder_layer_2 (Dense)      (None, 64)                8256      
_________________________________________________________________
encoder_layer_3 (Dense)      (None, 32)                2080      
_________________________________________________________________
encoder_layer_4 (Dense)      (None, 16)                528       
_________________________________________________________________
code (Dense)                 (None, 8)                 136       
_________________________________________________________________
decoder_layer_1 (Dense)      (None, 16)                144       
_________________________________________________________________
decoder_layer_2 (Dense)      (None, 32)                544       
_________________________________________________________________
decoder_layer_3 (Dense)      (None, 64)                2112      
_________________________________________________________________
decoder_layer_4 (Dense)      (None, 128)               8320      
_________________________________________________________________
final_layer (Dense)          (None, 12288)             1585152   
_________________________________________________________________
reshape_9 (Reshape)          (None, 64, 64, 3)         0         
=================================================================
Total params: 3,180,264
Trainable params: 3,180,264
Non-trainable params: 0
_________________________________________________________________
In [68]:
%%time
model.fit(images_train, images_train, batch_size=16, epochs=500, validation_split=0.2, callbacks=[checkpoint, early_stopping], shuffle=True)
model.save(model_file) # Save Best model to disk
Epoch 1/500
2225/2225 [==============================] - 16s 7ms/step - loss: 7745.4814 - mse: 7745.4814 - val_loss: 5789.5469 - val_mse: 5789.5469
Epoch 2/500
2225/2225 [==============================] - 15s 7ms/step - loss: 5094.6670 - mse: 5094.6670 - val_loss: 4754.9946 - val_mse: 4754.9946
Epoch 3/500
2225/2225 [==============================] - 15s 7ms/step - loss: 4244.2397 - mse: 4244.2397 - val_loss: 4173.7920 - val_mse: 4173.7920
Epoch 4/500
2225/2225 [==============================] - 15s 7ms/step - loss: 3777.8196 - mse: 3777.8196 - val_loss: 3882.5774 - val_mse: 3882.5774
Epoch 5/500
2225/2225 [==============================] - 15s 7ms/step - loss: 3416.8728 - mse: 3416.8728 - val_loss: 3642.4468 - val_mse: 3642.4468
Epoch 6/500
2225/2225 [==============================] - 15s 7ms/step - loss: 3163.1895 - mse: 3163.1895 - val_loss: 3392.8254 - val_mse: 3392.8254
Epoch 7/500
2225/2225 [==============================] - 15s 7ms/step - loss: 3011.4729 - mse: 3011.4729 - val_loss: 3302.9238 - val_mse: 3302.9238
Epoch 8/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2893.6626 - mse: 2893.6626 - val_loss: 3230.7539 - val_mse: 3230.7539
Epoch 9/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2801.2739 - mse: 2801.2739 - val_loss: 3111.6794 - val_mse: 3111.6794
Epoch 10/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2709.9768 - mse: 2709.9768 - val_loss: 3024.9519 - val_mse: 3024.9519
Epoch 11/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2628.3242 - mse: 2628.3242 - val_loss: 2966.7283 - val_mse: 2966.7283
Epoch 12/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2557.0449 - mse: 2557.0449 - val_loss: 2912.2588 - val_mse: 2912.2588
Epoch 13/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2502.5159 - mse: 2502.5159 - val_loss: 2850.7229 - val_mse: 2850.7229
Epoch 14/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2455.8535 - mse: 2455.8535 - val_loss: 2820.7717 - val_mse: 2820.7717
Epoch 15/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2415.6208 - mse: 2415.6208 - val_loss: 2782.8862 - val_mse: 2782.8862
Epoch 16/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2378.2502 - mse: 2378.2502 - val_loss: 2762.5403 - val_mse: 2762.5403
Epoch 17/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2352.1372 - mse: 2352.1372 - val_loss: 2739.2517 - val_mse: 2739.2517
Epoch 18/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2318.7959 - mse: 2318.7959 - val_loss: 2712.8945 - val_mse: 2712.8945
Epoch 19/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2296.4275 - mse: 2296.4275 - val_loss: 2706.2317 - val_mse: 2706.2317
Epoch 20/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2276.7935 - mse: 2276.7935 - val_loss: 2682.8230 - val_mse: 2682.8230
Epoch 21/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2259.1987 - mse: 2259.1987 - val_loss: 2671.7095 - val_mse: 2671.7095
Epoch 22/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2243.9900 - mse: 2243.9900 - val_loss: 2656.2014 - val_mse: 2656.2014
Epoch 23/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2225.9990 - mse: 2225.9990 - val_loss: 2628.5667 - val_mse: 2628.5667
Epoch 24/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2201.2419 - mse: 2201.2419 - val_loss: 2633.5908 - val_mse: 2633.5908
Epoch 25/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2187.1528 - mse: 2187.1528 - val_loss: 2603.9958 - val_mse: 2603.9958
Epoch 26/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2167.0247 - mse: 2167.0247 - val_loss: 2579.0828 - val_mse: 2579.0828
Epoch 27/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2147.0837 - mse: 2147.0837 - val_loss: 2565.5955 - val_mse: 2565.5955
Epoch 28/500
2225/2225 [==============================] - 14s 7ms/step - loss: 2131.2476 - mse: 2131.2476 - val_loss: 2553.0916 - val_mse: 2553.0916
Epoch 29/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2125.1418 - mse: 2125.1418 - val_loss: 2546.6794 - val_mse: 2546.6794
Epoch 30/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2113.3552 - mse: 2113.3552 - val_loss: 2535.7832 - val_mse: 2535.7832
Epoch 31/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2106.2405 - mse: 2106.2405 - val_loss: 2535.9009 - val_mse: 2535.9009
Epoch 32/500
2225/2225 [==============================] - 14s 7ms/step - loss: 2100.7708 - mse: 2100.7708 - val_loss: 2521.9089 - val_mse: 2521.9089
Epoch 33/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2093.4861 - mse: 2093.4861 - val_loss: 2529.3059 - val_mse: 2529.3059
Epoch 34/500
2225/2225 [==============================] - 14s 7ms/step - loss: 2086.1904 - mse: 2086.1904 - val_loss: 2506.7351 - val_mse: 2506.7351
Epoch 35/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2084.8545 - mse: 2084.8545 - val_loss: 2516.1262 - val_mse: 2516.1262
Epoch 36/500
2225/2225 [==============================] - 14s 7ms/step - loss: 2077.9331 - mse: 2077.9331 - val_loss: 2503.3787 - val_mse: 2503.3787
Epoch 37/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2072.5806 - mse: 2072.5806 - val_loss: 2494.9321 - val_mse: 2494.9321
Epoch 38/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2067.7234 - mse: 2067.7234 - val_loss: 2494.3889 - val_mse: 2494.3889
Epoch 39/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2062.2871 - mse: 2062.2871 - val_loss: 2483.0850 - val_mse: 2483.0850
Epoch 40/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2057.3281 - mse: 2057.3281 - val_loss: 2478.1245 - val_mse: 2478.1245
Epoch 41/500
2225/2225 [==============================] - 14s 7ms/step - loss: 2055.6375 - mse: 2055.6375 - val_loss: 2477.5952 - val_mse: 2477.5952
Epoch 42/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2047.7416 - mse: 2047.7416 - val_loss: 2474.7520 - val_mse: 2474.7520
Epoch 43/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2044.8254 - mse: 2044.8254 - val_loss: 2487.7310 - val_mse: 2487.7310
Epoch 44/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2041.1714 - mse: 2041.1714 - val_loss: 2473.2947 - val_mse: 2473.2947
Epoch 45/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2038.4335 - mse: 2038.4335 - val_loss: 2460.0046 - val_mse: 2460.0046
Epoch 46/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2036.6165 - mse: 2036.6165 - val_loss: 2459.0994 - val_mse: 2459.0994
Epoch 47/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2034.9250 - mse: 2034.9250 - val_loss: 2459.0425 - val_mse: 2459.0425
Epoch 48/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2031.9937 - mse: 2031.9937 - val_loss: 2445.1523 - val_mse: 2445.1523
Epoch 49/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2029.2157 - mse: 2029.2157 - val_loss: 2448.8923 - val_mse: 2448.8923
Epoch 50/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2027.5762 - mse: 2027.5762 - val_loss: 2450.2930 - val_mse: 2450.2930
Epoch 51/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2024.1421 - mse: 2024.1421 - val_loss: 2455.2170 - val_mse: 2455.2170
Epoch 52/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2020.7971 - mse: 2020.7971 - val_loss: 2451.8428 - val_mse: 2451.8428
Epoch 53/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2016.3461 - mse: 2016.3461 - val_loss: 2430.4321 - val_mse: 2430.4321
Epoch 54/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2011.9696 - mse: 2011.9696 - val_loss: 2423.0098 - val_mse: 2423.0098
Epoch 55/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2008.6443 - mse: 2008.6443 - val_loss: 2413.8279 - val_mse: 2413.8279
Epoch 56/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2006.3228 - mse: 2006.3228 - val_loss: 2420.9802 - val_mse: 2420.9802
Epoch 57/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2005.4141 - mse: 2005.4141 - val_loss: 2434.0698 - val_mse: 2434.0698
Epoch 58/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2002.9851 - mse: 2002.9851 - val_loss: 2432.4451 - val_mse: 2432.4451
Epoch 59/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2001.9885 - mse: 2001.9885 - val_loss: 2427.0110 - val_mse: 2427.0110
Epoch 60/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1999.6281 - mse: 1999.6281 - val_loss: 2413.1631 - val_mse: 2413.1631
Epoch 61/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1999.6019 - mse: 1999.6019 - val_loss: 2405.9392 - val_mse: 2405.9392
Epoch 62/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1997.6655 - mse: 1997.6655 - val_loss: 2401.0386 - val_mse: 2401.0386
Epoch 63/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1996.4545 - mse: 1996.4545 - val_loss: 2411.3904 - val_mse: 2411.3904
Epoch 64/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1995.7986 - mse: 1995.7986 - val_loss: 2403.6282 - val_mse: 2403.6282
Epoch 65/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1994.1902 - mse: 1994.1902 - val_loss: 2420.5930 - val_mse: 2420.5930
Epoch 66/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1993.4651 - mse: 1993.4651 - val_loss: 2401.1953 - val_mse: 2401.1953
Epoch 67/500
2225/2225 [==============================] - 15s 7ms/step - loss: 2024.5470 - mse: 2024.5470 - val_loss: 2400.9795 - val_mse: 2400.9795
Epoch 68/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1989.7943 - mse: 1989.7943 - val_loss: 2400.9592 - val_mse: 2400.9592
Epoch 69/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1984.5483 - mse: 1984.5483 - val_loss: 2392.7205 - val_mse: 2392.7205
Epoch 70/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1987.0037 - mse: 1987.0037 - val_loss: 2406.4531 - val_mse: 2406.4531
Epoch 71/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1987.8333 - mse: 1987.8333 - val_loss: 2407.9504 - val_mse: 2407.9504
Epoch 72/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1985.6984 - mse: 1985.6984 - val_loss: 2397.6138 - val_mse: 2397.6138
Epoch 73/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1985.7360 - mse: 1985.7360 - val_loss: 2416.3306 - val_mse: 2416.3306
Epoch 74/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1985.3004 - mse: 1985.3004 - val_loss: 2391.9360 - val_mse: 2391.9360
Epoch 75/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1982.4799 - mse: 1982.4799 - val_loss: 2402.4446 - val_mse: 2402.4446
Epoch 76/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1985.8153 - mse: 1985.8153 - val_loss: 2387.0242 - val_mse: 2387.0242
Epoch 77/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1979.2655 - mse: 1979.2655 - val_loss: 2397.2800 - val_mse: 2397.2800
Epoch 78/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1978.4922 - mse: 1978.4922 - val_loss: 2398.9546 - val_mse: 2398.9546
Epoch 79/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1978.4847 - mse: 1978.4847 - val_loss: 2398.9705 - val_mse: 2398.9705
Epoch 80/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1978.7698 - mse: 1978.7698 - val_loss: 2404.4834 - val_mse: 2404.4834
Epoch 81/500
2225/2225 [==============================] - 15s 7ms/step - loss: 1978.0459 - mse: 1978.0459 - val_loss: 2407.9304 - val_mse: 2407.9304
CPU times: user 19min 15s, sys: 2min 14s, total: 21min 29s
Wall time: 20min 4s
In [69]:
!mkdir -p drive/MyDrive/datasets/autoencoder/models_animefaces
!cp model_ae_dnn.h5 drive/MyDrive/datasets/autoencoder/models_animefaces
!ls -lh drive/MyDrive/datasets/autoencoder/models_animefaces
total 74M
-rw------- 1 root root 6.0M Jun  5 13:22 model_ae_cnn.h5
-rw------- 1 root root  37M Jun  5 14:56 model_ae_dnn.h5
-rw------- 1 root root  31M Jun  5 10:41 model_ae_lstm.h5
In [58]:
model_file = '/content/drive/MyDrive/datasets/autoencoder/models_animefaces/model_ae_dnn.h5'
# model.load_weights(model_file)  # Load best model
model = tf.keras.models.load_model(model_file) # Load entire model
In [70]:
model.evaluate(images_test, images_test, batch_size=8, verbose=True)
2384/2384 [==============================] - 7s 3ms/step - loss: 2376.7542 - mse: 2376.7542
Out[70]:
[2376.754150390625, 2376.754150390625]
In [71]:
def display_accuracy(model, image_actual, n_col=4, text=""):
  print("=================================== %s ===============================" % text)
  image_generated = model.predict(image_actual, batch_size=8, verbose=False).astype(np.uint8)
  image_generated[image_generated > 255] = 255
  image_generated[image_generated < 0] = 0

  images_side_by_side = np.concatenate([image_actual, image_generated], axis=2)
  plot_images(images_side_by_side, n_col=n_col)

images_to_display = 16
display_accuracy(model, images_train[:images_to_display], text="Train Output")
display_accuracy(model, images_test[:images_to_display], text="Prediction Output")
=================================== Train Output ===============================
=================================== Prediction Output ===============================

Code value - Intermediate representation of image¶

In [72]:
from tensorflow import keras

# Layers to be used
layers = [keras.Input(shape=images_shape, dtype=tf.int8)]
layers.extend(model.layers[:6])

model_code_generator = keras.Sequential(layers)
model_code_generator.build((None, images_shape[0], images_shape[1], images_shape[2]))

for layer in model_code_generator.layers:
  if list(filter(lambda x: x in layer.name, ['flatten', 'reshape'])):
    continue
  assert all([np.array_equal(layer.get_weights()[0], model.get_layer(layer.name).get_weights()[0]), 
              np.array_equal(layer.get_weights()[1], model.get_layer(layer.name).get_weights()[1])]),  "%s weights not same" % layer.name

model_code_generator.summary()
WARNING:tensorflow:Please add `keras.layers.InputLayer` instead of `keras.Input` to Sequential model. `keras.Input` is intended to be used by Functional model.
Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
flatten_9 (Flatten)          (None, 12288)             0         
_________________________________________________________________
encoder_layer_1 (Dense)      (None, 128)               1572992   
_________________________________________________________________
encoder_layer_2 (Dense)      (None, 64)                8256      
_________________________________________________________________
encoder_layer_3 (Dense)      (None, 32)                2080      
_________________________________________________________________
encoder_layer_4 (Dense)      (None, 16)                528       
_________________________________________________________________
code (Dense)                 (None, 8)                 136       
=================================================================
Total params: 1,583,992
Trainable params: 1,583,992
Non-trainable params: 0
_________________________________________________________________
In [73]:
# imgs = model_code_generator.predict(images_test[:4], batch_size=8, verbose=False).astype(np.uint8)
# plot_images(imgs, n_col=8)
# imgs = model.predict(images_test[:4], batch_size=8, verbose=False).astype(np.uint8)
# plot_images(imgs, n_col=8)
In [74]:
codes = model_code_generator.predict(images_test[:16], batch_size=8, verbose=False)
codes.shape
Out[74]:
(16, 8)
In [75]:
print(codes[0].tolist())
print(codes[1].tolist())
print(codes[2].tolist())
[1215.6807861328125, 1890.59423828125, -1772.0360107421875, 381.75714111328125, -574.5813598632812, 311.3032531738281, -895.1006469726562, 616.4902954101562]
[-1336.106689453125, -1303.230712890625, -1296.32275390625, 246.3556671142578, -646.2127075195312, -150.3787384033203, -323.1648864746094, -382.8296813964844]
[703.5673217773438, 599.2479858398438, -926.3267822265625, -22.82808494567871, -1733.0594482421875, 1482.5084228515625, -673.0850219726562, -205.01351928710938]
In [76]:
code_stats = { 
    "min" : np.min(codes), 
    "max" : np.max(codes), 
    "mean": np.mean(codes),
    "std": np.std(codes)
}
code_stats
Out[76]:
{'max': 2597.495, 'mean': -68.190186, 'min': -3308.1372, 'std': 1030.903}

Lets generate some random images¶

But we need to remove some extra layers before that, now we know that code layer has 8 neurons. So we are going to generate some random 8 numbers and will pass it to out decoder layer

In [77]:
import tensorflow as tf
model_file = '/content/drive/MyDrive/datasets/autoencoder/models_animefaces/model_ae_dnn.h5'
model = tf.keras.models.load_model(model_file) # Load entire model
# model.summary()
In [78]:
from tensorflow import keras
model_generator = keras.Sequential(model.layers[6:])
model_generator.build((None, 8))
model_generator.summary()
Model: "sequential_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
decoder_layer_1 (Dense)      (None, 16)                144       
_________________________________________________________________
decoder_layer_2 (Dense)      (None, 32)                544       
_________________________________________________________________
decoder_layer_3 (Dense)      (None, 64)                2112      
_________________________________________________________________
decoder_layer_4 (Dense)      (None, 128)               8320      
_________________________________________________________________
final_layer (Dense)          (None, 12288)             1585152   
_________________________________________________________________
reshape_9 (Reshape)          (None, 64, 64, 3)         0         
=================================================================
Total params: 1,596,272
Trainable params: 1,596,272
Non-trainable params: 0
_________________________________________________________________
In [81]:
import numpy as np
inputs  = np.random.normal(code_stats['mean'], code_stats['std'], (16, 8))
# inputs = codes
image_generated = model_generator.predict(inputs, batch_size=8, verbose=False).astype(np.uint8)
image_generated[image_generated > 255] = 255
image_generated[image_generated < 0] = 0
plot_images(image_generated, n_col=8)

Autoencoder - LSTM¶

This would be similar to Dense n/w as desribed above, but we will use LSTM layers this time

Training¶

In [11]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint

model_file = 'model_ae_lstm.h5'

model = keras.Sequential(name="my_sequential")
model.add(tf.keras.layers.InputLayer(input_shape=images_shape))
model.add(layers.Reshape((images_shape[0], images_shape[1] * images_shape[2])))
model.add(tf.keras.layers.LSTM(64, activation='tanh', return_sequences=True, name="encoder_layer_1"))
model.add(tf.keras.layers.LSTM(32, activation='tanh', return_sequences=True, name="encoder_layer_2"))
model.add(tf.keras.layers.LSTM(16, activation='tanh', name="encoder_layer_3"))

model.add(layers.Dense(8, name="code"))
model.add(layers.Reshape((2, 4)))

model.add(tf.keras.layers.LSTM(16, activation='tanh', return_sequences=True, name="decoder_layer_1"))
model.add(tf.keras.layers.LSTM(32, activation='tanh', return_sequences=True, name="decoder_layer_2"))
model.add(tf.keras.layers.LSTM(64, activation='tanh', name="decoder_layer_3"))
model.add(layers.Dense(total_pixels, activation="relu", name="final_layer"))
model.add(layers.Reshape(images_shape))

checkpoint = ModelCheckpoint(model_file, verbose=0, monitor='val_loss', save_best_only=True, mode='auto')
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.MeanSquaredError(), 
              metrics=['mse']
              )
model.summary()
Model: "my_sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
reshape (Reshape)            (None, 64, 192)           0         
_________________________________________________________________
encoder_layer_1 (LSTM)       (None, 64, 64)            65792     
_________________________________________________________________
encoder_layer_2 (LSTM)       (None, 64, 32)            12416     
_________________________________________________________________
encoder_layer_3 (LSTM)       (None, 16)                3136      
_________________________________________________________________
code (Dense)                 (None, 8)                 136       
_________________________________________________________________
reshape_1 (Reshape)          (None, 2, 4)              0         
_________________________________________________________________
decoder_layer_1 (LSTM)       (None, 2, 16)             1344      
_________________________________________________________________
decoder_layer_2 (LSTM)       (None, 2, 32)             6272      
_________________________________________________________________
decoder_layer_3 (LSTM)       (None, 64)                24832     
_________________________________________________________________
final_layer (Dense)          (None, 12288)             798720    
_________________________________________________________________
reshape_2 (Reshape)          (None, 64, 64, 3)         0         
=================================================================
Total params: 912,648
Trainable params: 912,648
Non-trainable params: 0
_________________________________________________________________
In [12]:
%%time
model.fit(images_train, images_train, batch_size=16, epochs=500, validation_split=0.2, callbacks=[checkpoint, early_stopping], shuffle=True)
model.save(model_file) # Save Best model to disk
Epoch 1/500
2225/2225 [==============================] - 47s 15ms/step - loss: 17278.6152 - mse: 17278.6152 - val_loss: 10705.0771 - val_mse: 10705.0771
Epoch 2/500
2225/2225 [==============================] - 33s 15ms/step - loss: 9260.0801 - mse: 9260.0801 - val_loss: 8613.3975 - val_mse: 8613.3975
Epoch 3/500
2225/2225 [==============================] - 33s 15ms/step - loss: 8594.0107 - mse: 8594.0107 - val_loss: 8547.8838 - val_mse: 8547.8838
Epoch 4/500
2225/2225 [==============================] - 33s 15ms/step - loss: 8581.1357 - mse: 8581.1357 - val_loss: 8547.7734 - val_mse: 8547.7734
Epoch 5/500
2225/2225 [==============================] - 32s 14ms/step - loss: 8581.2842 - mse: 8581.2842 - val_loss: 8547.7871 - val_mse: 8547.7871
Epoch 6/500
2225/2225 [==============================] - 32s 14ms/step - loss: 8581.1934 - mse: 8581.1934 - val_loss: 8547.9473 - val_mse: 8547.9473
Epoch 7/500
2225/2225 [==============================] - 32s 14ms/step - loss: 8576.4355 - mse: 8576.4355 - val_loss: 8525.8848 - val_mse: 8525.8848
Epoch 8/500
2225/2225 [==============================] - 32s 14ms/step - loss: 8556.7031 - mse: 8556.7031 - val_loss: 8523.1943 - val_mse: 8523.1943
Epoch 9/500
2225/2225 [==============================] - 32s 15ms/step - loss: 8556.5117 - mse: 8556.5117 - val_loss: 8523.4219 - val_mse: 8523.4219
Epoch 10/500
2225/2225 [==============================] - 32s 14ms/step - loss: 8547.7285 - mse: 8547.7285 - val_loss: 8353.8105 - val_mse: 8353.8105
Epoch 11/500
2225/2225 [==============================] - 32s 14ms/step - loss: 8227.0508 - mse: 8227.0508 - val_loss: 8235.6152 - val_mse: 8235.6152
Epoch 12/500
2225/2225 [==============================] - 32s 14ms/step - loss: 8137.2378 - mse: 8137.2378 - val_loss: 8029.8247 - val_mse: 8029.8247
Epoch 13/500
2225/2225 [==============================] - 32s 14ms/step - loss: 7558.8306 - mse: 7558.8306 - val_loss: 7133.8076 - val_mse: 7133.8076
Epoch 14/500
2225/2225 [==============================] - 32s 14ms/step - loss: 6959.6494 - mse: 6959.6494 - val_loss: 6826.3511 - val_mse: 6826.3511
Epoch 15/500
2225/2225 [==============================] - 32s 15ms/step - loss: 6767.2334 - mse: 6767.2334 - val_loss: 6688.7368 - val_mse: 6688.7368
Epoch 16/500
2225/2225 [==============================] - 33s 15ms/step - loss: 6598.5073 - mse: 6598.5073 - val_loss: 6468.4165 - val_mse: 6468.4165
Epoch 17/500
2225/2225 [==============================] - 33s 15ms/step - loss: 6368.7534 - mse: 6368.7534 - val_loss: 6244.3647 - val_mse: 6244.3647
Epoch 18/500
2225/2225 [==============================] - 33s 15ms/step - loss: 6151.1299 - mse: 6151.1299 - val_loss: 6091.9585 - val_mse: 6091.9585
Epoch 19/500
2225/2225 [==============================] - 33s 15ms/step - loss: 6023.8628 - mse: 6023.8628 - val_loss: 5955.0894 - val_mse: 5955.0894
Epoch 20/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5918.6499 - mse: 5918.6499 - val_loss: 5845.9141 - val_mse: 5845.9141
Epoch 21/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5840.9941 - mse: 5840.9941 - val_loss: 5792.4121 - val_mse: 5792.4121
Epoch 22/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5797.5967 - mse: 5797.5967 - val_loss: 5735.8999 - val_mse: 5735.8999
Epoch 23/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5742.4775 - mse: 5742.4775 - val_loss: 5697.0938 - val_mse: 5697.0938
Epoch 24/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5682.7651 - mse: 5682.7651 - val_loss: 5656.0010 - val_mse: 5656.0010
Epoch 25/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5639.8833 - mse: 5639.8833 - val_loss: 5594.0493 - val_mse: 5594.0493
Epoch 26/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5588.4316 - mse: 5588.4316 - val_loss: 5547.0728 - val_mse: 5547.0728
Epoch 27/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5540.7373 - mse: 5540.7373 - val_loss: 5540.0312 - val_mse: 5540.0312
Epoch 28/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5475.1758 - mse: 5475.1758 - val_loss: 5439.8794 - val_mse: 5439.8794
Epoch 29/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5434.9487 - mse: 5434.9487 - val_loss: 5411.1914 - val_mse: 5411.1914
Epoch 30/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5402.4243 - mse: 5402.4243 - val_loss: 5433.1772 - val_mse: 5433.1772
Epoch 31/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5370.1484 - mse: 5370.1484 - val_loss: 5325.5493 - val_mse: 5325.5493
Epoch 32/500
2225/2225 [==============================] - 35s 16ms/step - loss: 5324.7983 - mse: 5324.7983 - val_loss: 5298.7197 - val_mse: 5298.7197
Epoch 33/500
2225/2225 [==============================] - 34s 15ms/step - loss: 5269.9653 - mse: 5269.9653 - val_loss: 5215.4131 - val_mse: 5215.4131
Epoch 34/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5208.8257 - mse: 5208.8257 - val_loss: 5173.6958 - val_mse: 5173.6958
Epoch 35/500
2225/2225 [==============================] - 34s 15ms/step - loss: 5146.5547 - mse: 5146.5547 - val_loss: 5082.8062 - val_mse: 5082.8062
Epoch 36/500
2225/2225 [==============================] - 33s 15ms/step - loss: 5085.5649 - mse: 5085.5649 - val_loss: 5106.9917 - val_mse: 5106.9917
Epoch 37/500
2225/2225 [==============================] - 32s 15ms/step - loss: 5068.1494 - mse: 5068.1494 - val_loss: 5122.3872 - val_mse: 5122.3872
Epoch 38/500
2225/2225 [==============================] - 32s 14ms/step - loss: 5073.0225 - mse: 5073.0225 - val_loss: 5051.2993 - val_mse: 5051.2993
Epoch 39/500
2225/2225 [==============================] - 32s 14ms/step - loss: 5058.7979 - mse: 5058.7979 - val_loss: 5053.2832 - val_mse: 5053.2832
Epoch 40/500
2225/2225 [==============================] - 32s 14ms/step - loss: 5037.4282 - mse: 5037.4282 - val_loss: 4976.8667 - val_mse: 4976.8667
Epoch 41/500
2225/2225 [==============================] - 32s 14ms/step - loss: 4939.2544 - mse: 4939.2544 - val_loss: 4872.5215 - val_mse: 4872.5215
Epoch 42/500
2225/2225 [==============================] - 32s 15ms/step - loss: 4829.6265 - mse: 4829.6265 - val_loss: 4772.4727 - val_mse: 4772.4727
Epoch 43/500
2225/2225 [==============================] - 32s 14ms/step - loss: 4742.4395 - mse: 4742.4395 - val_loss: 4697.9766 - val_mse: 4697.9766
Epoch 44/500
2225/2225 [==============================] - 32s 14ms/step - loss: 4682.7646 - mse: 4682.7646 - val_loss: 4663.2148 - val_mse: 4663.2148
Epoch 45/500
2225/2225 [==============================] - 32s 14ms/step - loss: 4615.2954 - mse: 4615.2954 - val_loss: 4563.9487 - val_mse: 4563.9487
Epoch 46/500
2225/2225 [==============================] - 32s 14ms/step - loss: 4535.0283 - mse: 4535.0283 - val_loss: 4484.2603 - val_mse: 4484.2603
Epoch 47/500
2225/2225 [==============================] - 32s 14ms/step - loss: 4435.7251 - mse: 4435.7251 - val_loss: 4380.4775 - val_mse: 4380.4775
Epoch 48/500
2225/2225 [==============================] - 32s 14ms/step - loss: 4282.7090 - mse: 4282.7090 - val_loss: 4190.1650 - val_mse: 4190.1650
Epoch 49/500
2225/2225 [==============================] - 32s 14ms/step - loss: 4159.7949 - mse: 4159.7949 - val_loss: 4125.0225 - val_mse: 4125.0225
Epoch 50/500
2225/2225 [==============================] - 32s 14ms/step - loss: 4088.7727 - mse: 4088.7727 - val_loss: 4081.5081 - val_mse: 4081.5081
Epoch 51/500
2225/2225 [==============================] - 32s 14ms/step - loss: 3796.1492 - mse: 3796.1492 - val_loss: 3567.1746 - val_mse: 3567.1746
Epoch 52/500
2225/2225 [==============================] - 32s 14ms/step - loss: 3517.8994 - mse: 3517.8994 - val_loss: 3468.4580 - val_mse: 3468.4580
Epoch 53/500
2225/2225 [==============================] - 33s 15ms/step - loss: 3443.5837 - mse: 3443.5837 - val_loss: 3410.4177 - val_mse: 3410.4177
Epoch 54/500
2225/2225 [==============================] - 33s 15ms/step - loss: 3397.5950 - mse: 3397.5950 - val_loss: 3372.3838 - val_mse: 3372.3838
Epoch 55/500
2225/2225 [==============================] - 33s 15ms/step - loss: 3355.9170 - mse: 3355.9170 - val_loss: 3345.6899 - val_mse: 3345.6899
Epoch 56/500
2225/2225 [==============================] - 33s 15ms/step - loss: 3325.5049 - mse: 3325.5049 - val_loss: 3329.6016 - val_mse: 3329.6016
Epoch 57/500
2225/2225 [==============================] - 33s 15ms/step - loss: 3298.4011 - mse: 3298.4011 - val_loss: 3280.9995 - val_mse: 3280.9995
Epoch 58/500
2225/2225 [==============================] - 33s 15ms/step - loss: 3226.9290 - mse: 3226.9290 - val_loss: 3194.6494 - val_mse: 3194.6494
Epoch 59/500
2225/2225 [==============================] - 33s 15ms/step - loss: 3165.4016 - mse: 3165.4016 - val_loss: 3141.0205 - val_mse: 3141.0205
Epoch 60/500
2225/2225 [==============================] - 33s 15ms/step - loss: 3126.9771 - mse: 3126.9771 - val_loss: 3118.6807 - val_mse: 3118.6807
Epoch 61/500
2225/2225 [==============================] - 33s 15ms/step - loss: 3091.4814 - mse: 3091.4814 - val_loss: 3055.6978 - val_mse: 3055.6978
Epoch 62/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2996.6882 - mse: 2996.6882 - val_loss: 2959.2883 - val_mse: 2959.2883
Epoch 63/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2914.3850 - mse: 2914.3850 - val_loss: 2868.0161 - val_mse: 2868.0161
Epoch 64/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2819.0962 - mse: 2819.0962 - val_loss: 2770.9736 - val_mse: 2770.9736
Epoch 65/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2718.3696 - mse: 2718.3696 - val_loss: 2690.1924 - val_mse: 2690.1924
Epoch 66/500
2225/2225 [==============================] - 32s 15ms/step - loss: 2582.4478 - mse: 2582.4478 - val_loss: 2506.6973 - val_mse: 2506.6973
Epoch 67/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2488.1816 - mse: 2488.1816 - val_loss: 2480.1951 - val_mse: 2480.1951
Epoch 68/500
2225/2225 [==============================] - 34s 15ms/step - loss: 2467.9194 - mse: 2467.9194 - val_loss: 2478.2026 - val_mse: 2478.2026
Epoch 69/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2451.6931 - mse: 2451.6931 - val_loss: 2444.1040 - val_mse: 2444.1040
Epoch 70/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2432.6746 - mse: 2432.6746 - val_loss: 2439.4841 - val_mse: 2439.4841
Epoch 71/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2416.6155 - mse: 2416.6155 - val_loss: 2406.9717 - val_mse: 2406.9717
Epoch 72/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2396.0857 - mse: 2396.0857 - val_loss: 2392.3086 - val_mse: 2392.3086
Epoch 73/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2386.6431 - mse: 2386.6431 - val_loss: 2387.3894 - val_mse: 2387.3894
Epoch 74/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2362.8672 - mse: 2362.8672 - val_loss: 2362.9543 - val_mse: 2362.9543
Epoch 75/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2347.3027 - mse: 2347.3027 - val_loss: 2347.9116 - val_mse: 2347.9116
Epoch 76/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2335.6294 - mse: 2335.6294 - val_loss: 2339.1584 - val_mse: 2339.1584
Epoch 77/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2324.3877 - mse: 2324.3877 - val_loss: 2324.0425 - val_mse: 2324.0425
Epoch 78/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2313.9614 - mse: 2313.9614 - val_loss: 2316.0813 - val_mse: 2316.0813
Epoch 79/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2308.7534 - mse: 2308.7534 - val_loss: 2303.6992 - val_mse: 2303.6992
Epoch 80/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2304.8267 - mse: 2304.8267 - val_loss: 2300.6892 - val_mse: 2300.6892
Epoch 81/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2295.8223 - mse: 2295.8223 - val_loss: 2318.0796 - val_mse: 2318.0796
Epoch 82/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2291.9814 - mse: 2291.9814 - val_loss: 2294.4219 - val_mse: 2294.4219
Epoch 83/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2285.5154 - mse: 2285.5154 - val_loss: 2280.4534 - val_mse: 2280.4534
Epoch 84/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2279.4600 - mse: 2279.4600 - val_loss: 2280.0750 - val_mse: 2280.0750
Epoch 85/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2270.3259 - mse: 2270.3259 - val_loss: 2264.2275 - val_mse: 2264.2275
Epoch 86/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2260.8247 - mse: 2260.8247 - val_loss: 2262.0779 - val_mse: 2262.0779
Epoch 87/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2251.9761 - mse: 2251.9761 - val_loss: 2253.4373 - val_mse: 2253.4373
Epoch 88/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2247.7524 - mse: 2247.7524 - val_loss: 2266.4182 - val_mse: 2266.4182
Epoch 89/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2242.2366 - mse: 2242.2366 - val_loss: 2243.9790 - val_mse: 2243.9790
Epoch 90/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2237.4939 - mse: 2237.4939 - val_loss: 2237.2122 - val_mse: 2237.2122
Epoch 91/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2228.7690 - mse: 2228.7690 - val_loss: 2229.9817 - val_mse: 2229.9817
Epoch 92/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2219.3467 - mse: 2219.3467 - val_loss: 2235.1709 - val_mse: 2235.1709
Epoch 93/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2211.1423 - mse: 2211.1423 - val_loss: 2209.8882 - val_mse: 2209.8882
Epoch 94/500
2225/2225 [==============================] - 32s 14ms/step - loss: 2203.2751 - mse: 2203.2751 - val_loss: 2204.6851 - val_mse: 2204.6851
Epoch 95/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2196.6384 - mse: 2196.6384 - val_loss: 2197.9944 - val_mse: 2197.9944
Epoch 96/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2191.7378 - mse: 2191.7378 - val_loss: 2192.9355 - val_mse: 2192.9355
Epoch 97/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2185.7500 - mse: 2185.7500 - val_loss: 2190.8318 - val_mse: 2190.8318
Epoch 98/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2181.7219 - mse: 2181.7219 - val_loss: 2187.3789 - val_mse: 2187.3789
Epoch 99/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2178.2847 - mse: 2178.2847 - val_loss: 2184.3501 - val_mse: 2184.3501
Epoch 100/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2171.7197 - mse: 2171.7197 - val_loss: 2172.8926 - val_mse: 2172.8926
Epoch 101/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2168.5957 - mse: 2168.5957 - val_loss: 2169.3792 - val_mse: 2169.3792
Epoch 102/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2165.2861 - mse: 2165.2861 - val_loss: 2173.1895 - val_mse: 2173.1895
Epoch 103/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2163.6592 - mse: 2163.6592 - val_loss: 2162.5911 - val_mse: 2162.5911
Epoch 104/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2157.7532 - mse: 2157.7532 - val_loss: 2177.4827 - val_mse: 2177.4827
Epoch 105/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2155.2971 - mse: 2155.2971 - val_loss: 2160.7231 - val_mse: 2160.7231
Epoch 106/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2152.8140 - mse: 2152.8140 - val_loss: 2153.8572 - val_mse: 2153.8572
Epoch 107/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2149.6211 - mse: 2149.6211 - val_loss: 2153.2310 - val_mse: 2153.2310
Epoch 108/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2148.5159 - mse: 2148.5159 - val_loss: 2154.9385 - val_mse: 2154.9385
Epoch 109/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2144.4446 - mse: 2144.4446 - val_loss: 2147.2644 - val_mse: 2147.2644
Epoch 110/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2142.0940 - mse: 2142.0940 - val_loss: 2144.1699 - val_mse: 2144.1699
Epoch 111/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2138.1355 - mse: 2138.1355 - val_loss: 2143.0337 - val_mse: 2143.0337
Epoch 112/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2135.3999 - mse: 2135.3999 - val_loss: 2140.4258 - val_mse: 2140.4258
Epoch 113/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2137.5798 - mse: 2137.5798 - val_loss: 2151.2314 - val_mse: 2151.2314
Epoch 114/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2135.3606 - mse: 2135.3606 - val_loss: 2139.0652 - val_mse: 2139.0652
Epoch 115/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2134.5042 - mse: 2134.5042 - val_loss: 2142.6350 - val_mse: 2142.6350
Epoch 116/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2129.5422 - mse: 2129.5422 - val_loss: 2135.6750 - val_mse: 2135.6750
Epoch 117/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2128.1431 - mse: 2128.1431 - val_loss: 2142.6777 - val_mse: 2142.6777
Epoch 118/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2124.3354 - mse: 2124.3354 - val_loss: 2126.6646 - val_mse: 2126.6646
Epoch 119/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2122.3333 - mse: 2122.3333 - val_loss: 2125.1572 - val_mse: 2125.1572
Epoch 120/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2123.1204 - mse: 2123.1204 - val_loss: 2126.4717 - val_mse: 2126.4717
Epoch 121/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2117.8911 - mse: 2117.8911 - val_loss: 2135.2971 - val_mse: 2135.2971
Epoch 122/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2120.0750 - mse: 2120.0750 - val_loss: 2123.0679 - val_mse: 2123.0679
Epoch 123/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2115.5303 - mse: 2115.5303 - val_loss: 2122.5771 - val_mse: 2122.5771
Epoch 124/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2112.2561 - mse: 2112.2561 - val_loss: 2121.9148 - val_mse: 2121.9148
Epoch 125/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2111.0742 - mse: 2111.0742 - val_loss: 2117.6472 - val_mse: 2117.6472
Epoch 126/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2110.0996 - mse: 2110.0996 - val_loss: 2116.0806 - val_mse: 2116.0806
Epoch 127/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2114.4026 - mse: 2114.4026 - val_loss: 2120.9204 - val_mse: 2120.9204
Epoch 128/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2108.3433 - mse: 2108.3433 - val_loss: 2113.1472 - val_mse: 2113.1472
Epoch 129/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2106.2334 - mse: 2106.2334 - val_loss: 2127.8787 - val_mse: 2127.8787
Epoch 130/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2104.1953 - mse: 2104.1953 - val_loss: 2113.1770 - val_mse: 2113.1770
Epoch 131/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2104.2295 - mse: 2104.2295 - val_loss: 2107.2539 - val_mse: 2107.2539
Epoch 132/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2100.5710 - mse: 2100.5710 - val_loss: 2118.6841 - val_mse: 2118.6841
Epoch 133/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2101.0947 - mse: 2101.0947 - val_loss: 2108.8284 - val_mse: 2108.8284
Epoch 134/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2097.8513 - mse: 2097.8513 - val_loss: 2108.7864 - val_mse: 2108.7864
Epoch 135/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2100.6631 - mse: 2100.6631 - val_loss: 2099.6821 - val_mse: 2099.6821
Epoch 136/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2098.4268 - mse: 2098.4268 - val_loss: 2106.6753 - val_mse: 2106.6753
Epoch 137/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2095.9814 - mse: 2095.9814 - val_loss: 2129.8088 - val_mse: 2129.8088
Epoch 138/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2096.1062 - mse: 2096.1062 - val_loss: 2100.0891 - val_mse: 2100.0891
Epoch 139/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2091.3901 - mse: 2091.3901 - val_loss: 2098.1531 - val_mse: 2098.1531
Epoch 140/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2089.8413 - mse: 2089.8413 - val_loss: 2098.5381 - val_mse: 2098.5381
Epoch 141/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2091.3118 - mse: 2091.3118 - val_loss: 2101.8169 - val_mse: 2101.8169
Epoch 142/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2094.5698 - mse: 2094.5698 - val_loss: 2110.2651 - val_mse: 2110.2651
Epoch 143/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2090.4109 - mse: 2090.4109 - val_loss: 2098.2056 - val_mse: 2098.2056
Epoch 144/500
2225/2225 [==============================] - 33s 15ms/step - loss: 2085.6516 - mse: 2085.6516 - val_loss: 2099.2151 - val_mse: 2099.2151
CPU times: user 1h 28min 50s, sys: 5min 56s, total: 1h 34min 46s
Wall time: 1h 18min 44s
In [13]:
!mkdir -p drive/MyDrive/datasets/autoencoder/models_animefaces
!cp model_ae_lstm.h5 drive/MyDrive/datasets/autoencoder/models_animefaces
!ls -lh drive/MyDrive/datasets/autoencoder/models_animefaces
total 53M
-rw------- 1 root root 6.0M Jun  5 13:22 model_ae_cnn.h5
-rw------- 1 root root  37M Jun  5 14:56 model_ae_dnn.h5
-rw------- 1 root root  11M Jun  6 07:45 model_ae_lstm.h5
In [14]:
# model_file = '/content/drive/MyDrive/datasets/autoencoder/models_animefaces/model_ae_lstm.h5'
# model.load_weights(model_file)  # Load best model
model = tf.keras.models.load_model(model_file) # Load entire model
In [15]:
model.evaluate(images_test, images_test, batch_size=8, verbose=True)
2384/2384 [==============================] - 17s 6ms/step - loss: 2092.5908 - mse: 2092.5908
Out[15]:
[2092.5908203125, 2092.5908203125]
In [16]:
def display_accuracy(model, image_actual, n_col=4, text=""):
  print("=================================== %s ===============================" % text)
  image_generated = model.predict(image_actual, batch_size=8, verbose=False).astype(np.uint8)
  image_generated[image_generated > 255] = 255
  image_generated[image_generated < 0] = 0

  images_side_by_side = np.concatenate([image_actual, image_generated], axis=2)
  plot_images(images_side_by_side, n_col=n_col)

images_to_display = 16
display_accuracy(model, images_train[:images_to_display], text="Train Output")
display_accuracy(model, images_test[:images_to_display], text="Prediction Output")
=================================== Train Output ===============================
=================================== Prediction Output ===============================

Code value - Intermediate representation of image¶

In [20]:
from tensorflow import keras

# Layers to be used
layers = [tf.keras.layers.InputLayer(input_shape=images_shape)]
layers.extend(model.layers[:5])

model_code_generator = keras.Sequential(layers)
model_code_generator.build((None, images_shape[0], images_shape[1], images_shape[2]))

for layer in model_code_generator.layers:
  if list(filter(lambda x: x in layer.name, ['flatten', 'reshape'])):
    continue
  assert all([np.array_equal(layer.get_weights()[0], model.get_layer(layer.name).get_weights()[0]), 
              np.array_equal(layer.get_weights()[1], model.get_layer(layer.name).get_weights()[1])]),  "%s weights not same" % layer.name

model_code_generator.summary()
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
reshape (Reshape)            (None, 64, 192)           0         
_________________________________________________________________
encoder_layer_1 (LSTM)       (None, 64, 64)            65792     
_________________________________________________________________
encoder_layer_2 (LSTM)       (None, 64, 32)            12416     
_________________________________________________________________
encoder_layer_3 (LSTM)       (None, 16)                3136      
_________________________________________________________________
code (Dense)                 (None, 8)                 136       
=================================================================
Total params: 81,480
Trainable params: 81,480
Non-trainable params: 0
_________________________________________________________________
In [21]:
codes = model_code_generator.predict(images_test[:16], batch_size=8, verbose=False)
codes.shape
Out[21]:
(16, 8)
In [22]:
print(codes[0].tolist())
print(codes[1].tolist())
print(codes[2].tolist())
[-0.26205557584762573, -0.33697181940078735, -0.3860102891921997, 0.07216334342956543, -0.08162796497344971, 0.21333178877830505, 0.44444242119789124, -0.7647983431816101]
[-0.411038339138031, 0.4667309522628784, -0.6031712293624878, -0.2982359230518341, 0.4706922173500061, -0.06073388457298279, -0.12359648942947388, 0.189530611038208]
[-0.05309383571147919, -0.9304435849189758, -0.7033981680870056, -0.44504719972610474, -0.6713100671768188, 1.3494720458984375, -0.4281803071498871, -0.707094132900238]
In [23]:
code_stats = { 
    "min" : np.min(codes), 
    "max" : np.max(codes), 
    "mean": np.mean(codes),
    "std": np.std(codes)
}
code_stats
Out[23]:
{'max': 1.349472, 'mean': -0.058297068, 'min': -1.4627534, 'std': 0.5171938}

Lets generate some random images¶

But we need to remove some extra layers before that, now we know that code layer has 8 neurons. So we are going to generate some random 8 numbers and will pass it to out decoder layer

In [24]:
import tensorflow as tf
model_file = '/content/drive/MyDrive/datasets/autoencoder/models_animefaces/model_ae_lstm.h5'
model = tf.keras.models.load_model(model_file) # Load entire model
# model.summary()
In [26]:
from tensorflow import keras
model_generator = keras.Sequential(model.layers[5:])
model_generator.build((None, 8))
model_generator.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
reshape_1 (Reshape)          (None, 2, 4)              0         
_________________________________________________________________
decoder_layer_1 (LSTM)       (None, 2, 16)             1344      
_________________________________________________________________
decoder_layer_2 (LSTM)       (None, 2, 32)             6272      
_________________________________________________________________
decoder_layer_3 (LSTM)       (None, 64)                24832     
_________________________________________________________________
final_layer (Dense)          (None, 12288)             798720    
_________________________________________________________________
reshape_2 (Reshape)          (None, 64, 64, 3)         0         
=================================================================
Total params: 831,168
Trainable params: 831,168
Non-trainable params: 0
_________________________________________________________________
In [30]:
import numpy as np
inputs  = np.random.normal(code_stats['mean'], code_stats['std'], (16, 8))
# inputs = codes
image_generated = model_generator.predict(inputs, batch_size=8, verbose=False).astype(np.uint8)
image_generated[image_generated > 255] = 255
image_generated[image_generated < 0] = 0
plot_images(image_generated, n_col=8)

Autoencoder - CNN¶

This would be similar to Dense n/w as desribed above, but we will use CNN layers this time

Training¶

In [13]:
# from numba import cuda 
# device = cuda.get_current_device()
# device.reset()
In [11]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint

model_file = 'model_ae_cnn.h5'

filter_size = (3, 3)
model = keras.Sequential(name="autoencoder_cnn")

model.add(tf.keras.layers.InputLayer(input_shape=images_shape))
# model.add(tf.keras.layers.Conv2D(256, filter_size, activation='relu', padding='same', strides=2, name='encoder_layer_1'))
model.add(tf.keras.layers.Conv2D(128, filter_size, activation='relu', padding='same', strides=2, name='encoder_layer_1'))
model.add(tf.keras.layers.Conv2D(64, filter_size, activation='relu', padding='same', strides=2, name='encoder_layer_2'))
model.add(tf.keras.layers.Conv2D(32, filter_size, activation='relu', padding='same', strides=2, name='encoder_layer_3'))
model.add(tf.keras.layers.Conv2D(16, filter_size, activation='relu', padding='same', strides=2, name='encoder_layer_4'))

model.add(layers.Flatten())
model.add(layers.Dense(8, name="code"))
model.add(layers.Reshape((2, 2, 2)))

model.add(tf.keras.layers.Conv2DTranspose(16, filter_size, activation='relu', padding='same', strides=2, name='decoder_layer_1'))
model.add(tf.keras.layers.Conv2DTranspose(32, filter_size, activation='relu', padding='same', strides=2, name='decoder_layer_2'))
model.add(tf.keras.layers.Conv2DTranspose(64, filter_size, activation='relu', padding='same', strides=2, name='decoder_layer_3'))
model.add(tf.keras.layers.Conv2DTranspose(128, filter_size, activation='relu', padding='same', strides=2, name='decoder_layer_4'))
# model.add(tf.keras.layers.Conv2DTranspose(256, filter_size, activation='relu', padding='same', strides=2, name='decoder_layer_5'))

model.add(tf.keras.layers.Conv2DTranspose(3, filter_size, activation='relu', padding='same', strides=2, name='decoder_layer_6'))

checkpoint = ModelCheckpoint(model_file, verbose=0, monitor='val_loss', save_best_only=True, mode='auto')
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.MeanSquaredError(), 
              metrics=['mse']
              )
model.summary()
Model: "autoencoder_cnn"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
encoder_layer_1 (Conv2D)     (None, 32, 32, 128)       3584      
_________________________________________________________________
encoder_layer_2 (Conv2D)     (None, 16, 16, 64)        73792     
_________________________________________________________________
encoder_layer_3 (Conv2D)     (None, 8, 8, 32)          18464     
_________________________________________________________________
encoder_layer_4 (Conv2D)     (None, 4, 4, 16)          4624      
_________________________________________________________________
flatten (Flatten)            (None, 256)               0         
_________________________________________________________________
code (Dense)                 (None, 8)                 2056      
_________________________________________________________________
reshape (Reshape)            (None, 2, 2, 2)           0         
_________________________________________________________________
decoder_layer_1 (Conv2DTrans (None, 4, 4, 16)          304       
_________________________________________________________________
decoder_layer_2 (Conv2DTrans (None, 8, 8, 32)          4640      
_________________________________________________________________
decoder_layer_3 (Conv2DTrans (None, 16, 16, 64)        18496     
_________________________________________________________________
decoder_layer_4 (Conv2DTrans (None, 32, 32, 128)       73856     
_________________________________________________________________
decoder_layer_6 (Conv2DTrans (None, 64, 64, 3)         3459      
=================================================================
Total params: 203,275
Trainable params: 203,275
Non-trainable params: 0
_________________________________________________________________
In [39]:
%%time
model.fit(images_train, images_train, batch_size=32, epochs=500, validation_split=0.2, callbacks=[checkpoint, early_stopping], shuffle=True)
model.save(model_file) # Save Best model to disk
Epoch 1/500
1113/1113 [==============================] - 42s 14ms/step - loss: 3058.9998 - mse: 3058.9998 - val_loss: 2511.7271 - val_mse: 2511.7271
Epoch 2/500
1113/1113 [==============================] - 15s 14ms/step - loss: 2468.7471 - mse: 2468.7471 - val_loss: 2443.8110 - val_mse: 2443.8110
Epoch 3/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2411.2112 - mse: 2411.2112 - val_loss: 2394.8528 - val_mse: 2394.8528
Epoch 4/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2384.5640 - mse: 2384.5640 - val_loss: 2383.0754 - val_mse: 2383.0754
Epoch 5/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2369.5110 - mse: 2369.5110 - val_loss: 2363.0452 - val_mse: 2363.0452
Epoch 6/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2359.4460 - mse: 2359.4460 - val_loss: 2373.9758 - val_mse: 2373.9758
Epoch 7/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2339.5972 - mse: 2339.5972 - val_loss: 2325.3003 - val_mse: 2325.3003
Epoch 8/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2319.3804 - mse: 2319.3804 - val_loss: 2307.9702 - val_mse: 2307.9702
Epoch 9/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2300.3269 - mse: 2300.3269 - val_loss: 2301.1951 - val_mse: 2301.1951
Epoch 10/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2286.0310 - mse: 2286.0310 - val_loss: 2281.6335 - val_mse: 2281.6335
Epoch 11/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2275.1306 - mse: 2275.1306 - val_loss: 2270.2104 - val_mse: 2270.2104
Epoch 12/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2264.5574 - mse: 2264.5574 - val_loss: 2274.3042 - val_mse: 2274.3042
Epoch 13/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2257.6902 - mse: 2257.6902 - val_loss: 2270.7708 - val_mse: 2270.7708
Epoch 14/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2253.0427 - mse: 2253.0427 - val_loss: 2251.2915 - val_mse: 2251.2915
Epoch 15/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2247.2974 - mse: 2247.2974 - val_loss: 2247.7493 - val_mse: 2247.7493
Epoch 16/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2241.0981 - mse: 2241.0981 - val_loss: 2249.4192 - val_mse: 2249.4192
Epoch 17/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2237.9487 - mse: 2237.9487 - val_loss: 2238.3975 - val_mse: 2238.3975
Epoch 18/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2231.9199 - mse: 2231.9199 - val_loss: 2270.9529 - val_mse: 2270.9529
Epoch 19/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2227.9573 - mse: 2227.9573 - val_loss: 2248.0601 - val_mse: 2248.0601
Epoch 20/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2223.6274 - mse: 2223.6274 - val_loss: 2235.0117 - val_mse: 2235.0117
Epoch 21/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2219.8855 - mse: 2219.8855 - val_loss: 2221.4517 - val_mse: 2221.4517
Epoch 22/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2214.5239 - mse: 2214.5239 - val_loss: 2222.6326 - val_mse: 2222.6326
Epoch 23/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2209.3633 - mse: 2209.3633 - val_loss: 2233.6631 - val_mse: 2233.6631
Epoch 24/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2203.8315 - mse: 2203.8315 - val_loss: 2217.5784 - val_mse: 2217.5784
Epoch 25/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2197.1934 - mse: 2197.1934 - val_loss: 2205.1272 - val_mse: 2205.1272
Epoch 26/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2191.9189 - mse: 2191.9189 - val_loss: 2202.4661 - val_mse: 2202.4661
Epoch 27/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2185.0027 - mse: 2185.0027 - val_loss: 2192.4851 - val_mse: 2192.4851
Epoch 28/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2176.2422 - mse: 2176.2422 - val_loss: 2181.2500 - val_mse: 2181.2500
Epoch 29/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2169.8945 - mse: 2169.8945 - val_loss: 2177.0142 - val_mse: 2177.0142
Epoch 30/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2163.1399 - mse: 2163.1399 - val_loss: 2170.9331 - val_mse: 2170.9331
Epoch 31/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2158.5430 - mse: 2158.5430 - val_loss: 2165.5391 - val_mse: 2165.5391
Epoch 32/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2153.6353 - mse: 2153.6353 - val_loss: 2169.0059 - val_mse: 2169.0059
Epoch 33/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2150.2375 - mse: 2150.2375 - val_loss: 2184.1431 - val_mse: 2184.1431
Epoch 34/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2146.2983 - mse: 2146.2983 - val_loss: 2160.8652 - val_mse: 2160.8652
Epoch 35/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2141.6951 - mse: 2141.6951 - val_loss: 2162.7627 - val_mse: 2162.7627
Epoch 36/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2140.3684 - mse: 2140.3684 - val_loss: 2148.9524 - val_mse: 2148.9524
Epoch 37/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2138.5645 - mse: 2138.5645 - val_loss: 2149.0491 - val_mse: 2149.0491
Epoch 38/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2136.0417 - mse: 2136.0417 - val_loss: 2139.4277 - val_mse: 2139.4277
Epoch 39/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2134.2087 - mse: 2134.2087 - val_loss: 2147.2849 - val_mse: 2147.2849
Epoch 40/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2132.8669 - mse: 2132.8669 - val_loss: 2142.0872 - val_mse: 2142.0872
Epoch 41/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2128.8611 - mse: 2128.8611 - val_loss: 2143.2449 - val_mse: 2143.2449
Epoch 42/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2128.0315 - mse: 2128.0315 - val_loss: 2138.1052 - val_mse: 2138.1052
Epoch 43/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2125.8042 - mse: 2125.8042 - val_loss: 2137.4983 - val_mse: 2137.4983
Epoch 44/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2126.8896 - mse: 2126.8896 - val_loss: 2135.8518 - val_mse: 2135.8518
Epoch 45/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2125.1531 - mse: 2125.1531 - val_loss: 2145.0603 - val_mse: 2145.0603
Epoch 46/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2123.4487 - mse: 2123.4487 - val_loss: 2131.9419 - val_mse: 2131.9419
Epoch 47/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2121.2905 - mse: 2121.2905 - val_loss: 2146.5078 - val_mse: 2146.5078
Epoch 48/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2119.4956 - mse: 2119.4956 - val_loss: 2132.5427 - val_mse: 2132.5427
Epoch 49/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2118.7007 - mse: 2118.7007 - val_loss: 2127.1130 - val_mse: 2127.1130
Epoch 50/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2117.9265 - mse: 2117.9265 - val_loss: 2132.1787 - val_mse: 2132.1787
Epoch 51/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2117.5581 - mse: 2117.5581 - val_loss: 2129.9849 - val_mse: 2129.9849
Epoch 52/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2117.7019 - mse: 2117.7019 - val_loss: 2128.9792 - val_mse: 2128.9792
Epoch 53/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2115.8765 - mse: 2115.8765 - val_loss: 2128.2673 - val_mse: 2128.2673
Epoch 54/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2113.9507 - mse: 2113.9507 - val_loss: 2121.6313 - val_mse: 2121.6313
Epoch 55/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2113.0039 - mse: 2113.0039 - val_loss: 2128.6384 - val_mse: 2128.6384
Epoch 56/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2110.9111 - mse: 2110.9111 - val_loss: 2128.9453 - val_mse: 2128.9453
Epoch 57/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2111.1548 - mse: 2111.1548 - val_loss: 2123.9087 - val_mse: 2123.9087
Epoch 58/500
1113/1113 [==============================] - 16s 14ms/step - loss: 2110.3159 - mse: 2110.3159 - val_loss: 2124.1143 - val_mse: 2124.1143
Epoch 59/500
1113/1113 [==============================] - 16s 15ms/step - loss: 2112.7029 - mse: 2112.7029 - val_loss: 2127.9968 - val_mse: 2127.9968
CPU times: user 14min 33s, sys: 53.8 s, total: 15min 26s
Wall time: 16min 17s
In [40]:
!mkdir -p drive/MyDrive/datasets/autoencoder/models_animefaces
!cp model_ae_cnn.h5 drive/MyDrive/datasets/autoencoder/models_animefaces
!ls -lh drive/MyDrive/datasets/autoencoder/models_animefaces
total 50M
-rw------- 1 root root 2.5M Jun  6 08:16 model_ae_cnn.h5
-rw------- 1 root root  37M Jun  5 14:56 model_ae_dnn.h5
-rw------- 1 root root  11M Jun  6 07:45 model_ae_lstm.h5
In [12]:
model_file = '/content/drive/MyDrive/datasets/autoencoder/models_animefaces/model_ae_cnn.h5'
# model.load_weights(model_file)  # Load best model
model = tf.keras.models.load_model(model_file) # Load entire model
In [41]:
model.evaluate(images_test, images_test, batch_size=8, verbose=True)
2384/2384 [==============================] - 10s 4ms/step - loss: 2117.5776 - mse: 2117.5776
Out[41]:
[2117.57763671875, 2117.57763671875]
In [42]:
def display_accuracy(model, image_actual, n_col=4, text=""):
  print("=================================== %s ===============================" % text)
  image_generated = model.predict(image_actual, batch_size=8, verbose=False).astype(np.uint8)
  image_generated[image_generated > 255] = 255
  image_generated[image_generated < 0] = 0

  images_side_by_side = np.concatenate([image_actual, image_generated], axis=2)
  plot_images(images_side_by_side, n_col=n_col)

images_to_display = 16
display_accuracy(model, images_train[:images_to_display], text="Train Output")
display_accuracy(model, images_test[:images_to_display], text="Prediction Output")
=================================== Train Output ===============================
=================================== Prediction Output ===============================

Code value - Intermediate representation of image¶

In [16]:
from tensorflow import keras

# Layers to be used
layers = [tf.keras.layers.InputLayer(input_shape=images_shape)]
layers.extend(model.layers[:6])

model_code_generator = keras.Sequential(layers)
model_code_generator.build((None, images_shape[0], images_shape[1], images_shape[2]))

for layer in model_code_generator.layers:
  if list(filter(lambda x: x in layer.name, ['flatten', 'reshape'])):
    continue
  assert all([np.array_equal(layer.get_weights()[0], model.get_layer(layer.name).get_weights()[0]), 
              np.array_equal(layer.get_weights()[1], model.get_layer(layer.name).get_weights()[1])]),  "%s weights not same" % layer.name

model_code_generator.summary()
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
encoder_layer_1 (Conv2D)     (None, 32, 32, 128)       3584      
_________________________________________________________________
encoder_layer_2 (Conv2D)     (None, 16, 16, 64)        73792     
_________________________________________________________________
encoder_layer_3 (Conv2D)     (None, 8, 8, 32)          18464     
_________________________________________________________________
encoder_layer_4 (Conv2D)     (None, 4, 4, 16)          4624      
_________________________________________________________________
flatten_5 (Flatten)          (None, 256)               0         
_________________________________________________________________
code (Dense)                 (None, 8)                 2056      
=================================================================
Total params: 102,520
Trainable params: 102,520
Non-trainable params: 0
_________________________________________________________________
In [17]:
codes = model_code_generator.predict(images_test[:16], batch_size=8, verbose=False)
codes.shape
Out[17]:
(16, 8)
In [18]:
print(codes[0].tolist())
print(codes[1].tolist())
print(codes[2].tolist())
[-90.84349060058594, 78.14624786376953, 61.19371795654297, 47.44902038574219, 135.2177734375, 48.00750732421875, -166.69300842285156, -8.807793617248535]
[-82.30125427246094, 82.98831939697266, 61.37909698486328, 43.42222213745117, 97.01995849609375, 58.46400833129883, -44.247249603271484, 2.3804984092712402]
[-55.21834945678711, 40.70209503173828, 36.956214904785156, 42.93532180786133, 128.26377868652344, 51.0125846862793, -89.1243667602539, -20.50321388244629]
In [19]:
code_stats = { 
    "min" : np.min(codes), 
    "max" : np.max(codes), 
    "mean": np.mean(codes),
    "std": np.std(codes)
}
code_stats
Out[19]:
{'max': 136.47034, 'mean': 27.564598, 'min': -187.67128, 'std': 69.36067}

Lets generate some random images¶

But we need to remove some extra layers before that, now we know that code layer has 8 neurons. So we are going to generate some random 8 numbers and will pass it to out decoder layer

In [23]:
import tensorflow as tf
model_file = '/content/drive/MyDrive/datasets/autoencoder/models_animefaces/model_ae_cnn.h5'
model = tf.keras.models.load_model(model_file) # Load entire model
# model.summary()
In [25]:
from tensorflow import keras
model_generator = keras.Sequential(model.layers[6:])
model_generator.build((None, 8))
model_generator.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
reshape_8 (Reshape)          (None, 2, 2, 2)           0         
_________________________________________________________________
decoder_layer_1 (Conv2DTrans (None, 4, 4, 16)          304       
_________________________________________________________________
decoder_layer_2 (Conv2DTrans (None, 8, 8, 32)          4640      
_________________________________________________________________
decoder_layer_3 (Conv2DTrans (None, 16, 16, 64)        18496     
_________________________________________________________________
decoder_layer_4 (Conv2DTrans (None, 32, 32, 128)       73856     
_________________________________________________________________
decoder_layer_6 (Conv2DTrans (None, 64, 64, 3)         3459      
=================================================================
Total params: 100,755
Trainable params: 100,755
Non-trainable params: 0
_________________________________________________________________
In [55]:
import numpy as np
inputs  = np.random.normal(code_stats['mean'], code_stats['std'], (16, 8))
# inputs = codes
image_generated = model_generator.predict(inputs, batch_size=8, verbose=False).astype(np.uint8)
image_generated[image_generated > 255] = 255
image_generated[image_generated < 0] = 0
plot_images(image_generated, n_col=8)